# !pip install fbprophet
# !pip install plotly
import pandas as pd
import numpy as np
from tensorflow import keras
import matplotlib.pyplot as plt
import plotly.express as px
import plotly.figure_factory as ff
from fbprophet import Prophet
from sklearn.metrics import r2_score
# Load the data. Columns:
# Entity: Country name
# Date: Date on which the record was observed (YYYY-MM-DD string)
# Cases: Cumulative number of confirmed Covid-19 cases up to that date
# Deaths: Cumulative number of confirmed Covid-19 related deaths up to that date
# Daily tests: Number of tests conducted in the country on that date
# iso_alpha: Three-letter ISO country code (e.g. MEX, THA)
# month: Month of observation (1 = January)
# Read the dataset and preview its first rows
covid_df = pd.read_csv('Data/covid_data.csv')
covid_df.head(5)
| Entity | Date | Cases | Deaths | Daily tests | iso_alpha | month | |
|---|---|---|---|---|---|---|---|
| 0 | Mexico | 2020-01-01 | 0.0 | 0.0 | 25.0 | MEX | 1 |
| 1 | Mexico | 2020-01-02 | 0.0 | 0.0 | 72.0 | MEX | 1 |
| 2 | Mexico | 2020-01-03 | 0.0 | 0.0 | 89.0 | MEX | 1 |
| 3 | Thailand | 2020-01-04 | 0.0 | 0.0 | 2.0 | THA | 1 |
| 4 | Mexico | 2020-01-04 | 0.0 | 0.0 | 45.0 | MEX | 1 |
covid_df.iloc[:10]  # first ten rows
| Entity | Date | Cases | Deaths | Daily tests | iso_alpha | month | |
|---|---|---|---|---|---|---|---|
| 0 | Mexico | 2020-01-01 | 0.0 | 0.0 | 25.0 | MEX | 1 |
| 1 | Mexico | 2020-01-02 | 0.0 | 0.0 | 72.0 | MEX | 1 |
| 2 | Mexico | 2020-01-03 | 0.0 | 0.0 | 89.0 | MEX | 1 |
| 3 | Thailand | 2020-01-04 | 0.0 | 0.0 | 2.0 | THA | 1 |
| 4 | Mexico | 2020-01-04 | 0.0 | 0.0 | 45.0 | MEX | 1 |
| 5 | Thailand | 2020-01-05 | 0.0 | 0.0 | 2.0 | THA | 1 |
| 6 | Mexico | 2020-01-05 | 0.0 | 0.0 | 81.0 | MEX | 1 |
| 7 | Thailand | 2020-01-06 | 0.0 | 0.0 | 5.0 | THA | 1 |
| 8 | Mexico | 2020-01-06 | 0.0 | 0.0 | 167.0 | MEX | 1 |
| 9 | Mexico | 2020-01-07 | 0.0 | 0.0 | 187.0 | MEX | 1 |
covid_df.iloc[-10:]  # last ten rows
| Entity | Date | Cases | Deaths | Daily tests | iso_alpha | month | |
|---|---|---|---|---|---|---|---|
| 20635 | Libya | 2020-10-31 | 60628.0 | 847.0 | 3320.0 | LBY | 10 |
| 20636 | Mozambique | 2020-10-31 | 12777.0 | 91.0 | 1012.0 | MOZ | 10 |
| 20637 | Peru | 2020-10-31 | 900180.0 | 34411.0 | 0.0 | PER | 10 |
| 20638 | Malaysia | 2020-10-31 | 30899.0 | 249.0 | 17076.0 | MYS | 10 |
| 20639 | Sweden | 2020-10-31 | 129042.0 | 5998.0 | 27613.0 | SWE | 10 |
| 20640 | Saudi Arabia | 2020-10-31 | 346880.0 | 5383.0 | 44840.0 | SAU | 10 |
| 20641 | United States | 2020-10-31 | 9047427.0 | 229708.0 | 1161427.0 | USA | 10 |
| 20642 | South Africa | 2020-10-31 | 723682.0 | 19230.0 | 22150.0 | ZAF | 10 |
| 20643 | Malawi | 2020-10-31 | 5923.0 | 184.0 | 0.0 | MWI | 10 |
| 20644 | Pakistan | 2020-10-31 | 332993.0 | 6806.0 | 21688.0 | PAK | 10 |
covid_df.shape[0]  # number of rows
20645
covid_df.notna().sum()  # non-missing values per column
Entity 20645 Date 20645 Cases 20645 Deaths 20645 Daily tests 20645 iso_alpha 20645 month 20645 dtype: int64
# Count missing values in each column
covid_df.isna().sum()
Entity 0 Date 0 Cases 0 Deaths 0 Daily tests 0 iso_alpha 0 month 0 dtype: int64
# Show column dtypes, non-null counts and memory usage
covid_df.info()
<class 'pandas.core.frame.DataFrame'> RangeIndex: 20645 entries, 0 to 20644 Data columns (total 7 columns): # Column Non-Null Count Dtype --- ------ -------------- ----- 0 Entity 20645 non-null object 1 Date 20645 non-null object 2 Cases 20645 non-null float64 3 Deaths 20645 non-null float64 4 Daily tests 20645 non-null float64 5 iso_alpha 20645 non-null object 6 month 20645 non-null int64 dtypes: float64(3), int64(1), object(3) memory usage: 1.1+ MB
# Summary statistics (count, mean, std, quartiles) of the numeric columns
covid_df.describe()
| Cases | Deaths | Daily tests | month | |
|---|---|---|---|---|
| count | 2.064500e+04 | 20645.000000 | 2.064500e+04 | 20645.000000 |
| mean | 1.247184e+05 | 4693.475902 | 2.248130e+04 | 6.401259 |
| std | 6.068402e+05 | 17873.361923 | 1.055780e+05 | 2.399602 |
| min | 0.000000e+00 | 0.000000 | -3.743000e+03 | 1.000000 |
| 25% | 1.086000e+03 | 11.000000 | 9.700000e+01 | 4.000000 |
| 50% | 8.698000e+03 | 196.000000 | 2.216000e+03 | 6.000000 |
| 75% | 5.545200e+04 | 1522.000000 | 1.025400e+04 | 8.000000 |
| max | 9.047427e+06 | 229708.000000 | 1.492409e+06 | 10.000000 |
# Order the rows chronologically
covid_df = covid_df.sort_values('Date')
covid_df
| Entity | Date | Cases | Deaths | Daily tests | iso_alpha | month | |
|---|---|---|---|---|---|---|---|
| 0 | Mexico | 2020-01-01 | 0.0 | 0.0 | 25.0 | MEX | 1 |
| 1 | Mexico | 2020-01-02 | 0.0 | 0.0 | 72.0 | MEX | 1 |
| 2 | Mexico | 2020-01-03 | 0.0 | 0.0 | 89.0 | MEX | 1 |
| 3 | Thailand | 2020-01-04 | 0.0 | 0.0 | 2.0 | THA | 1 |
| 4 | Mexico | 2020-01-04 | 0.0 | 0.0 | 45.0 | MEX | 1 |
| ... | ... | ... | ... | ... | ... | ... | ... |
| 20585 | Australia | 2020-10-31 | 27582.0 | 907.0 | 0.0 | AUS | 10 |
| 20584 | India | 2020-10-31 | 8137119.0 | 121641.0 | 1067976.0 | IND | 10 |
| 20583 | New Zealand | 2020-10-31 | 1601.0 | 25.0 | 4401.0 | NZL | 10 |
| 20591 | France | 2020-10-31 | 1331984.0 | 36565.0 | 175333.0 | FRA | 10 |
| 20644 | Pakistan | 2020-10-31 | 332993.0 | 6806.0 | 21688.0 | PAK | 10 |
20645 rows × 7 columns
# How many distinct countries the dataset covers
covid_df['Entity'].nunique(dropna=True)
83
covid_df.Entity.unique()  # distinct country names, in order of first appearance
array(['Mexico', 'Thailand', 'Japan', 'United States', 'Vietnam',
'Switzerland', 'Nepal', 'France', 'Australia', 'Malaysia',
'Canada', 'Denmark', 'Israel', 'Czech Republic', 'Sri Lanka',
'India', 'Philippines', 'Finland', 'Italy', 'Sweden',
'United Kingdom', 'Belgium', 'South Africa', 'Guatemala', 'Iran',
'Morocco', 'Kuwait', 'Bahrain', 'Norway', 'Oman', 'Iraq',
'Austria', 'Croatia', 'Algeria', 'Pakistan', 'Romania', 'Greece',
'Iceland', 'Serbia', 'New Zealand', 'Senegal', 'Nigeria',
'Ireland', 'Ecuador', 'Portugal', 'Saudi Arabia',
'Dominican Republic', 'Indonesia', 'Bangladesh', 'Jordan',
'Tunisia', 'Chile', 'Poland', 'Togo', 'Libya', 'Slovenia',
'Hungary', 'Peru', 'Costa Rica', 'Paraguay', 'Colombia',
'Bulgaria', 'Panama', 'Bolivia', 'Jamaica', 'Turkey', 'Cuba',
'Trinidad and Tobago', 'Ghana', 'Kenya', 'Ethiopia', 'Mauritania',
'Namibia', 'Uruguay', 'Rwanda', 'Zambia', 'El Salvador',
'Madagascar', 'Zimbabwe', 'Uganda', 'Mozambique', 'Myanmar',
'Malawi'], dtype=object)
# Print every country in the dataset, each followed by a blank line
print('Countries on which have data are:\n')
countries = covid_df['Entity'].unique()
for country in countries:
    print(country + '\n')
Countries on which have data are: Mexico Thailand Japan United States Vietnam Switzerland Nepal France Australia Malaysia Canada Denmark Israel Czech Republic Sri Lanka India Philippines Finland Italy Sweden United Kingdom Belgium South Africa Guatemala Iran Morocco Kuwait Bahrain Norway Oman Iraq Austria Croatia Algeria Pakistan Romania Greece Iceland Serbia New Zealand Senegal Nigeria Ireland Ecuador Portugal Saudi Arabia Dominican Republic Indonesia Bangladesh Jordan Tunisia Chile Poland Togo Libya Slovenia Hungary Peru Costa Rica Paraguay Colombia Bulgaria Panama Bolivia Jamaica Turkey Cuba Trinidad and Tobago Ghana Kenya Ethiopia Mauritania Namibia Uruguay Rwanda Zambia El Salvador Madagascar Zimbabwe Uganda Mozambique Myanmar Malawi
# Reorder rows by the number of cases (ascending)
covid_df = covid_df.sort_values('Cases')
covid_df
| Entity | Date | Cases | Deaths | Daily tests | iso_alpha | month | |
|---|---|---|---|---|---|---|---|
| 0 | Mexico | 2020-01-01 | 0.0 | 0.0 | 25.0 | MEX | 1 |
| 1178 | Thailand | 2020-03-08 | 0.0 | 0.0 | 560.0 | THA | 3 |
| 158 | Switzerland | 2020-02-01 | 0.0 | 0.0 | 19.0 | CHE | 2 |
| 981 | Libya | 2020-03-04 | 0.0 | 0.0 | 3.0 | LBY | 3 |
| 192 | Switzerland | 2020-02-02 | 0.0 | 0.0 | 11.0 | CHE | 2 |
| ... | ... | ... | ... | ... | ... | ... | ... |
| 20244 | United States | 2020-10-27 | 8704524.0 | 225735.0 | 1229458.0 | USA | 10 |
| 20368 | United States | 2020-10-28 | 8779653.0 | 226723.0 | 1355447.0 | USA | 10 |
| 20410 | United States | 2020-10-29 | 8858024.0 | 227700.0 | 1366664.0 | USA | 10 |
| 20489 | United States | 2020-10-30 | 8946154.0 | 228668.0 | 1414156.0 | USA | 10 |
| 20641 | United States | 2020-10-31 | 9047427.0 | 229708.0 | 1161427.0 | USA | 10 |
20645 rows × 7 columns
covid_df  # still ordered by Cases
| Entity | Date | Cases | Deaths | Daily tests | iso_alpha | month | |
|---|---|---|---|---|---|---|---|
| 0 | Mexico | 2020-01-01 | 0.0 | 0.0 | 25.0 | MEX | 1 |
| 1178 | Thailand | 2020-03-08 | 0.0 | 0.0 | 560.0 | THA | 3 |
| 158 | Switzerland | 2020-02-01 | 0.0 | 0.0 | 19.0 | CHE | 2 |
| 981 | Libya | 2020-03-04 | 0.0 | 0.0 | 3.0 | LBY | 3 |
| 192 | Switzerland | 2020-02-02 | 0.0 | 0.0 | 11.0 | CHE | 2 |
| ... | ... | ... | ... | ... | ... | ... | ... |
| 20244 | United States | 2020-10-27 | 8704524.0 | 225735.0 | 1229458.0 | USA | 10 |
| 20368 | United States | 2020-10-28 | 8779653.0 | 226723.0 | 1355447.0 | USA | 10 |
| 20410 | United States | 2020-10-29 | 8858024.0 | 227700.0 | 1366664.0 | USA | 10 |
| 20489 | United States | 2020-10-30 | 8946154.0 | 228668.0 | 1414156.0 | USA | 10 |
| 20641 | United States | 2020-10-31 | 9047427.0 | 229708.0 | 1161427.0 | USA | 10 |
20645 rows × 7 columns
# Restore chronological order before plotting
covid_df = covid_df.sort_values('Date')
covid_df
| Entity | Date | Cases | Deaths | Daily tests | iso_alpha | month | |
|---|---|---|---|---|---|---|---|
| 0 | Mexico | 2020-01-01 | 0.0 | 0.0 | 25.0 | MEX | 1 |
| 1 | Mexico | 2020-01-02 | 0.0 | 0.0 | 72.0 | MEX | 1 |
| 2 | Mexico | 2020-01-03 | 0.0 | 0.0 | 89.0 | MEX | 1 |
| 4 | Mexico | 2020-01-04 | 0.0 | 0.0 | 45.0 | MEX | 1 |
| 3 | Thailand | 2020-01-04 | 0.0 | 0.0 | 2.0 | THA | 1 |
| ... | ... | ... | ... | ... | ... | ... | ... |
| 20615 | Ghana | 2020-10-31 | 48055.0 | 320.0 | 0.0 | GHA | 10 |
| 20585 | Australia | 2020-10-31 | 27582.0 | 907.0 | 0.0 | AUS | 10 |
| 20565 | Croatia | 2020-10-31 | 46547.0 | 531.0 | 8741.0 | HRV | 10 |
| 20605 | Czech Republic | 2020-10-31 | 323673.0 | 3078.0 | 36287.0 | CZE | 10 |
| 20641 | United States | 2020-10-31 | 9047427.0 | 229708.0 | 1161427.0 | USA | 10 |
20645 rows × 7 columns
# Function to plot an interactive line chart with one trace per group
def interactive_plot(df, column_name, title, group_column='Entity', x_column='Date'):
    """Draw an interactive plotly line chart with one trace per group.

    Args:
        df: DataFrame containing ``group_column``, ``x_column`` and ``column_name``.
        column_name: Column plotted on the y-axis.
        title: Figure title.
        group_column: Column whose distinct values each become one trace,
            named after the value (default ``'Entity'``, one line per country).
        x_column: Column plotted on the x-axis (default ``'Date'``).

    Rows whose ``group_column`` value is missing are not plotted (groupby
    drops missing keys).
    """
    fig = px.line(title=title)
    # A single groupby pass replaces re-filtering the whole frame once per group.
    # sort=False keeps groups in order of first appearance, like Series.unique().
    for group_name, group in df.groupby(group_column, sort=False):
        fig.add_scatter(x=group[x_column], y=group[column_name], name=group_name)
    fig.show()
# Line chart of cumulative cases, one line per country
interactive_plot(covid_df, column_name='Cases', title='Number of Covid cases')
# Line chart of daily tests, one line per country
interactive_plot(covid_df, column_name='Daily tests', title='Number of daily tests')
# MINI CHALLENGE #3: plot the number of deaths for all countries
# Line chart of cumulative deaths, one line per country
interactive_plot(covid_df, column_name='Deaths', title='Number of deaths for all countries')
covid_df
| Entity | Date | Cases | Deaths | Daily tests | iso_alpha | month | |
|---|---|---|---|---|---|---|---|
| 0 | Mexico | 2020-01-01 | 0.0 | 0.0 | 25.0 | MEX | 1 |
| 1 | Mexico | 2020-01-02 | 0.0 | 0.0 | 72.0 | MEX | 1 |
| 2 | Mexico | 2020-01-03 | 0.0 | 0.0 | 89.0 | MEX | 1 |
| 4 | Mexico | 2020-01-04 | 0.0 | 0.0 | 45.0 | MEX | 1 |
| 3 | Thailand | 2020-01-04 | 0.0 | 0.0 | 2.0 | THA | 1 |
| ... | ... | ... | ... | ... | ... | ... | ... |
| 20615 | Ghana | 2020-10-31 | 48055.0 | 320.0 | 0.0 | GHA | 10 |
| 20585 | Australia | 2020-10-31 | 27582.0 | 907.0 | 0.0 | AUS | 10 |
| 20565 | Croatia | 2020-10-31 | 46547.0 | 531.0 | 8741.0 | HRV | 10 |
| 20605 | Czech Republic | 2020-10-31 | 323673.0 | 3078.0 | 36287.0 | CZE | 10 |
| 20641 | United States | 2020-10-31 | 9047427.0 | 229708.0 | 1161427.0 | USA | 10 |
20645 rows × 7 columns
# Covid cases animation, one frame per month. hover_name matches the other
# maps so the country name is shown on hover.
fig = px.choropleth(covid_df, locations='iso_alpha', color='Cases',
                    hover_name='Entity', animation_frame='month')
fig.show()
# Covid deaths animation from January 2020 to October 2020, one frame per month
fig = px.choropleth(
    covid_df,
    locations="iso_alpha",    # three-letter ISO country code
    color='Deaths',           # column mapped to color intensity
    hover_name="Entity",      # country name shown on hover
    animation_frame='month',  # one animation frame per month
)
fig.show()
# Covid testing animation from January 2020 to October 2020, one frame per month
fig = px.choropleth(
    covid_df,
    locations="iso_alpha",    # three-letter ISO country code
    color='Daily tests',      # column mapped to color intensity
    hover_name="Entity",      # country name shown on hover
    animation_frame='month',  # one animation frame per month
)
fig.show()
# Covid cases animation with one frame per day
fig = px.choropleth(
    covid_df,
    locations='iso_alpha',
    color='Cases',
    hover_name="Entity",
    animation_frame='Date',
)
fig.show()
covid_df  # ordered by Date
| Entity | Date | Cases | Deaths | Daily tests | iso_alpha | month | |
|---|---|---|---|---|---|---|---|
| 0 | Mexico | 2020-01-01 | 0.0 | 0.0 | 25.0 | MEX | 1 |
| 1 | Mexico | 2020-01-02 | 0.0 | 0.0 | 72.0 | MEX | 1 |
| 2 | Mexico | 2020-01-03 | 0.0 | 0.0 | 89.0 | MEX | 1 |
| 4 | Mexico | 2020-01-04 | 0.0 | 0.0 | 45.0 | MEX | 1 |
| 3 | Thailand | 2020-01-04 | 0.0 | 0.0 | 2.0 | THA | 1 |
| ... | ... | ... | ... | ... | ... | ... | ... |
| 20615 | Ghana | 2020-10-31 | 48055.0 | 320.0 | 0.0 | GHA | 10 |
| 20585 | Australia | 2020-10-31 | 27582.0 | 907.0 | 0.0 | AUS | 10 |
| 20565 | Croatia | 2020-10-31 | 46547.0 | 531.0 | 8741.0 | HRV | 10 |
| 20605 | Czech Republic | 2020-10-31 | 323673.0 | 3078.0 | 36287.0 | CZE | 10 |
| 20641 | United States | 2020-10-31 | 9047427.0 | 229708.0 | 1161427.0 | USA | 10 |
20645 rows × 7 columns
# The forecasting part focuses on India: keep only that country's rows
IND_df = covid_df.loc[covid_df['Entity'] == 'India']
IND_df
| Entity | Date | Cases | Deaths | Daily tests | iso_alpha | month | |
|---|---|---|---|---|---|---|---|
| 127 | India | 2020-01-30 | 1.0 | 0.0 | 0.0 | IND | 1 |
| 140 | India | 2020-01-31 | 1.0 | 0.0 | 0.0 | IND | 1 |
| 156 | India | 2020-02-01 | 1.0 | 0.0 | 0.0 | IND | 2 |
| 190 | India | 2020-02-02 | 2.0 | 0.0 | 0.0 | IND | 2 |
| 194 | India | 2020-02-03 | 2.0 | 0.0 | 0.0 | IND | 2 |
| ... | ... | ... | ... | ... | ... | ... | ... |
| 20302 | India | 2020-10-27 | 7946429.0 | 119502.0 | 958116.0 | IND | 10 |
| 20359 | India | 2020-10-28 | 7990322.0 | 120010.0 | 1066786.0 | IND | 10 |
| 20454 | India | 2020-10-29 | 8040203.0 | 120527.0 | 1075760.0 | IND | 10 |
| 20506 | India | 2020-10-30 | 8088851.0 | 121090.0 | 1164648.0 | IND | 10 |
| 20584 | India | 2020-10-31 | 8137119.0 | 121641.0 | 1067976.0 | IND | 10 |
275 rows × 7 columns
# Keep only the Date and Cases columns. .copy() makes IND_df an independent
# frame: later cells reset its index, rename its columns and add a
# 'predicted' column, which would otherwise modify a possible copy of
# covid_df (SettingWithCopyWarning).
IND_df = IND_df[['Date', 'Cases']].copy()
IND_df
| Date | Cases | |
|---|---|---|
| 127 | 2020-01-30 | 1.0 |
| 140 | 2020-01-31 | 1.0 |
| 156 | 2020-02-01 | 1.0 |
| 190 | 2020-02-02 | 2.0 |
| 194 | 2020-02-03 | 2.0 |
| ... | ... | ... |
| 20302 | 2020-10-27 | 7946429.0 |
| 20359 | 2020-10-28 | 7990322.0 |
| 20454 | 2020-10-29 | 8040203.0 |
| 20506 | 2020-10-30 | 8088851.0 |
| 20584 | 2020-10-31 | 8137119.0 |
275 rows × 2 columns
# Renumber rows 0..n-1 (dropping the old covid_df row labels). Reassigning,
# rather than using inplace=True, avoids mutating a frame that may be derived
# from a slice of covid_df.
IND_df = IND_df.reset_index(drop=True)
IND_df
| Date | Cases | |
|---|---|---|
| 0 | 2020-01-30 | 1.0 |
| 1 | 2020-01-31 | 1.0 |
| 2 | 2020-02-01 | 1.0 |
| 3 | 2020-02-02 | 2.0 |
| 4 | 2020-02-03 | 2.0 |
| ... | ... | ... |
| 270 | 2020-10-27 | 7946429.0 |
| 271 | 2020-10-28 | 7990322.0 |
| 272 | 2020-10-29 | 8040203.0 |
| 273 | 2020-10-30 | 8088851.0 |
| 274 | 2020-10-31 | 8137119.0 |
275 rows × 2 columns
# Prophet expects the date column to be named 'ds' and the target 'y'
IND_df = IND_df.set_axis(['ds', 'y'], axis='columns')
IND_df
| ds | y | |
|---|---|---|
| 0 | 2020-01-30 | 1.0 |
| 1 | 2020-01-31 | 1.0 |
| 2 | 2020-02-01 | 1.0 |
| 3 | 2020-02-02 | 2.0 |
| 4 | 2020-02-03 | 2.0 |
| ... | ... | ... |
| 270 | 2020-10-27 | 7946429.0 |
| 271 | 2020-10-28 | 7990322.0 |
| 272 | 2020-10-29 | 8040203.0 |
| 273 | 2020-10-30 | 8088851.0 |
| 274 | 2020-10-31 | 8137119.0 |
275 rows × 2 columns
# Split the data into training (up to 2020-09-30) and testing (the rest) sets.
# 'ds' holds 'YYYY-MM-DD' strings, so string comparison orders them
# chronologically. The test set is the exact complement of the training set,
# so no row can fall between two separate conditions and be silently dropped.
split_date = '2020-09-30'
is_train = IND_df['ds'] <= split_date
train, test = IND_df[is_train], IND_df[~is_train]
train
| ds | y | |
|---|---|---|
| 0 | 2020-01-30 | 1.0 |
| 1 | 2020-01-31 | 1.0 |
| 2 | 2020-02-01 | 1.0 |
| 3 | 2020-02-02 | 2.0 |
| 4 | 2020-02-03 | 2.0 |
| ... | ... | ... |
| 239 | 2020-09-26 | 5903932.0 |
| 240 | 2020-09-27 | 5992532.0 |
| 241 | 2020-09-28 | 6074702.0 |
| 242 | 2020-09-29 | 6145291.0 |
| 243 | 2020-09-30 | 6225763.0 |
244 rows × 2 columns
test  # held-out evaluation rows
| ds | y | |
|---|---|---|
| 244 | 2020-10-01 | 6312584.0 |
| 245 | 2020-10-02 | 6394068.0 |
| 246 | 2020-10-03 | 6473544.0 |
| 247 | 2020-10-04 | 6549373.0 |
| 248 | 2020-10-05 | 6623815.0 |
| 249 | 2020-10-06 | 6685082.0 |
| 250 | 2020-10-07 | 6757131.0 |
| 251 | 2020-10-08 | 6835655.0 |
| 252 | 2020-10-09 | 6906151.0 |
| 253 | 2020-10-10 | 6979423.0 |
| 254 | 2020-10-11 | 7053806.0 |
| 255 | 2020-10-12 | 7120538.0 |
| 256 | 2020-10-13 | 7175880.0 |
| 257 | 2020-10-14 | 7239389.0 |
| 258 | 2020-10-15 | 7307097.0 |
| 259 | 2020-10-16 | 7370468.0 |
| 260 | 2020-10-17 | 7432680.0 |
| 261 | 2020-10-18 | 7494551.0 |
| 262 | 2020-10-19 | 7550273.0 |
| 263 | 2020-10-20 | 7597063.0 |
| 264 | 2020-10-21 | 7651107.0 |
| 265 | 2020-10-22 | 7706946.0 |
| 266 | 2020-10-23 | 7761312.0 |
| 267 | 2020-10-24 | 7814682.0 |
| 268 | 2020-10-25 | 7864811.0 |
| 269 | 2020-10-26 | 7909959.0 |
| 270 | 2020-10-27 | 7946429.0 |
| 271 | 2020-10-28 | 7990322.0 |
| 272 | 2020-10-29 | 8040203.0 |
| 273 | 2020-10-30 | 8088851.0 |
| 274 | 2020-10-31 | 8137119.0 |
# Rows for Canada only
Canada_df = covid_df.loc[covid_df['Entity'] == 'Canada']
Canada_df
| Entity | Date | Cases | Deaths | Daily tests | iso_alpha | month | |
|---|---|---|---|---|---|---|---|
| 78 | Canada | 2020-01-26 | 1.0 | 0.0 | 0.0 | CAN | 1 |
| 80 | Canada | 2020-01-27 | 1.0 | 0.0 | 0.0 | CAN | 1 |
| 94 | Canada | 2020-01-28 | 2.0 | 0.0 | 0.0 | CAN | 1 |
| 115 | Canada | 2020-01-29 | 3.0 | 0.0 | 0.0 | CAN | 1 |
| 130 | Canada | 2020-01-30 | 3.0 | 0.0 | 0.0 | CAN | 1 |
| ... | ... | ... | ... | ... | ... | ... | ... |
| 20291 | Canada | 2020-10-27 | 220213.0 | 9973.0 | 80889.0 | CAN | 10 |
| 20392 | Canada | 2020-10-28 | 222887.0 | 10001.0 | 55672.0 | CAN | 10 |
| 20444 | Canada | 2020-10-29 | 225586.0 | 10032.0 | 63142.0 | CAN | 10 |
| 20556 | Canada | 2020-10-30 | 228542.0 | 10074.0 | 62843.0 | CAN | 10 |
| 20580 | Canada | 2020-10-31 | 231999.0 | 10110.0 | 65562.0 | CAN | 10 |
280 rows × 7 columns
# Build a Prophet model with default settings, then fit it on the training set
m = Prophet()
m.fit(train)
INFO:fbprophet:Disabling yearly seasonality. Run prophet with yearly_seasonality=True to override this. INFO:fbprophet:Disabling daily seasonality. Run prophet with daily_seasonality=True to override this.
<fbprophet.forecaster.Prophet at 0x1713e18cf10>
# Forecast as many days as there are in the test set (31 for October 2020),
# so the forecast horizon always matches the held-out data. This assumes the
# test dates follow the training dates day by day, as they do here.
# make_future_dataframe returns the training dates plus that many future dates.
future = m.make_future_dataframe(periods=len(test))
# Predict for every date in `future`
forecast = m.predict(future)
# 'yhat' is the predicted mean; 'yhat_lower' and 'yhat_upper' bound the
# uncertainty interval
forecast[['ds', 'yhat', 'yhat_lower', 'yhat_upper']].tail()
| ds | yhat | yhat_lower | yhat_upper | |
|---|---|---|---|---|
| 270 | 2020-10-27 | 8.312279e+06 | 8.188998e+06 | 8.458773e+06 |
| 271 | 2020-10-28 | 8.392427e+06 | 8.263699e+06 | 8.548087e+06 |
| 272 | 2020-10-29 | 8.471188e+06 | 8.329333e+06 | 8.632345e+06 |
| 273 | 2020-10-30 | 8.552735e+06 | 8.410187e+06 | 8.734294e+06 |
| 274 | 2020-10-31 | 8.634322e+06 | 8.472134e+06 | 8.822688e+06 |
from fbprophet.plot import plot_plotly, plot_components_plotly

# Interactive plot of the observed data together with the forecast
plot_plotly(m, forecast)
from fbprophet.plot import add_changepoints_to_plot

# Overlay the trend changepoints inferred by the model on the forecast plot
fig = m.plot(forecast)
axes = fig.gca()
a = add_changepoints_to_plot(axes, m, forecast)
#future_60 = m.make_future_dataframe(periods=60)
##forecast for 60 days
#forecast_60 = m.predict(future_60)
#forecast_60[['ds', 'yhat', 'yhat_lower', 'yhat_upper']].tail()
#from fbprophet.plot import plot_plotly, plot_components_plotly
## Ploting the forecasted data for 60 days
#plot_plotly(m, forecast_60)
#fig = m.plot(forecast_60)
#a = add_changepoints_to_plot(fig.gca(), m, forecast_60)
# Assess the model on the held-out test data. The model's full prediction
# ('yhat') is scored, not only its 'trend' component, and predictions are
# matched to the observed values by date rather than by row position. If a
# test date has no forecast, 'yhat' is NaN and r2_score raises instead of
# silently comparing misaligned rows.
evaluation = test.assign(ds=pd.to_datetime(test['ds'])).merge(
    forecast[['ds', 'yhat']], on='ds', how='left')
score = r2_score(evaluation['y'], evaluation['yhat'])
print('R-Square score is {}'.format(score))
R-Sqaure score is 0.8438170836019973
# Add the model's predictions ('yhat') to IND_df for plotting. Each row looks
# up its prediction by date, so the values line up even if the two frames'
# indexes differ; dates without a forecast get NaN.
IND_df['predicted'] = pd.to_datetime(IND_df['ds']).map(forecast.set_index('ds')['yhat'])
IND_df
| ds | y | predicted | |
|---|---|---|---|
| 0 | 2020-01-30 | 1.0 | -2.271097e+03 |
| 1 | 2020-01-31 | 1.0 | -2.196588e+03 |
| 2 | 2020-02-01 | 1.0 | -2.122079e+03 |
| 3 | 2020-02-02 | 2.0 | -2.047570e+03 |
| 4 | 2020-02-03 | 2.0 | -1.973062e+03 |
| ... | ... | ... | ... |
| 270 | 2020-10-27 | 7946429.0 | 8.312573e+06 |
| 271 | 2020-10-28 | 7990322.0 | 8.392960e+06 |
| 272 | 2020-10-29 | 8040203.0 | 8.473347e+06 |
| 273 | 2020-10-30 | 8088851.0 | 8.553734e+06 |
| 274 | 2020-10-31 | 8137119.0 | 8.634121e+06 |
275 rows × 3 columns
# Function to plot the forecast and the original values for comparison
def interactive_plot_forecasting(df, title, x_column='ds'):
    """Plot every column of ``df`` except ``x_column`` as a line against ``x_column``.

    Args:
        df: DataFrame with an x-axis column plus one or more value columns
            (here the observed 'y' and the model's 'predicted').
        title: Figure title.
        x_column: Column used for the x-axis (default ``'ds'``).

    Each value column becomes one trace named after the column.
    """
    fig = px.line(title=title)
    # Select value columns by name rather than position, so the x column does
    # not have to be the first one.
    for column in df.columns:
        if column == x_column:
            continue
        fig.add_scatter(x=df[x_column], y=df[column], name=column)
    fig.show()
interactive_plot_forecasting(df=IND_df, title='Original Vs Predicted')